{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo: Extract intermediate representations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Define a random model\n",
    "e.g., ResNet-18 in torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "\n",
    "model = models.resnet18(pretrained=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Define a forward hook manager\n",
    "Let's say we want to extract representations from the following layers (modules) in [ResNet-18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L230-L249):  \n",
    "- \"conv1\": input\n",
    "- \"layer1.0.bn2\": input and output\n",
    "- \"fc\": output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdistill.core.forward_hook import ForwardHookManager\n",
    "\n",
    "device = torch.device('cpu')\n",
    "forward_hook_manager = ForwardHookManager(device)\n",
    "forward_hook_manager.add_hook(model, 'conv1', requires_input=True, requires_output=False)\n",
    "forward_hook_manager.add_hook(model, 'layer1.0.bn2', requires_input=True, requires_output=True)\n",
    "forward_hook_manager.add_hook(model, 'fc', requires_input=False, requires_output=True)"
   ]
  },
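  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is a minimal sketch of the mechanism such a manager can build on: PyTorch's own `register_forward_hook`. It is not torchdistill's actual implementation. The toy model and the names `toy_model`, `captured`, and `capture_io` are assumptions made for illustration only.  \n",
    "The hook stores the input and output of one submodule, looked up by its dotted path via `named_modules()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Toy stand-in for ResNet-18, kept small so the sketch runs quickly (an assumption)\n",
    "toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
    "\n",
    "captured = {}  # hypothetical store: module path -> {'input': ..., 'output': ...}\n",
    "\n",
    "def capture_io(module_path):\n",
    "    # Build a forward hook that records the module's input and output\n",
    "    def hook(module, inputs, output):\n",
    "        # PyTorch passes positional inputs as a tuple; keep the first one\n",
    "        captured[module_path] = {'input': inputs[0], 'output': output}\n",
    "    return hook\n",
    "\n",
    "# named_modules() maps dotted paths ('0' here, 'layer1.0.bn2' in ResNet-18) to modules\n",
    "target = dict(toy_model.named_modules())['0']\n",
    "handle = target.register_forward_hook(capture_io('0'))\n",
    "\n",
    "toy_x = torch.rand(3, 4)\n",
    "toy_y = toy_model(toy_x)\n",
    "handle.remove()  # detach the hook once the tensors are captured\n",
    "\n",
    "print(torch.equal(captured['0']['input'], toy_x))\n",
    "print(captured['0']['output'].shape)"
   ]
  },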
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Execute forward function of the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that ResNet-18's input shape for ImageNet dataset is 3x224x224.  \n",
    "Here we use batch size of 32 for a random input batch *x*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.rand(32, 3, 224, 224)\n",
    "y = model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Get I/O dictionary from the manager\n",
    "There should be three keys: \"conv1\", \"layer1.bn2\", and \"fc\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['conv1', 'layer1.0.bn2', 'fc'])\n"
     ]
    }
   ],
   "source": [
    "io_dict = forward_hook_manager.pop_io_dict()\n",
    "print(io_dict.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Make sure that input to \"conv1\" matches *x*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['input'])\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(io_dict['conv1'].keys())\n",
    "conv1_input = io_dict['conv1']['input']\n",
    "print(torch.equal(x, conv1_input))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Similarly, make sure that output from \"fc\" matches *y*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['output'])\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(io_dict['fc'].keys())\n",
    "fc_output = io_dict['fc']['output']\n",
    "print(torch.equal(y, fc_output))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Check if the extracted input/output tensor from \"layer1.0.bn2\"..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['input', 'output'])\n"
     ]
    }
   ],
   "source": [
    "print(io_dict['layer1.0.bn2'].keys())\n",
    "layer1_bn2_input = io_dict['layer1.0.bn2']['input']\n",
    "layer1_bn2_output = io_dict['layer1.0.bn2']['output']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check if the extracted tensors match those used in the model, you'll need to reimplement [ResNet-18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L230-L249) to change the interface of forward function in **ResNet** and **BasicBlock** classes so that you can get the intermediate tensors in addition to *y* from ***model*** shown in Section 4.  \n",
    "You don't want to, right? That's why you want to use **ForwardHookManager** instead."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}